{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright 2021-2022 @ Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd\n",
    "\n",
    "This code is a part of Cybertron package.\n",
    "\n",
    "The Cybertron is open-source software based on the AI-framework:\n",
    "MindSpore (https://www.mindspore.cn/)\n",
    "\n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "you may not use this file except in compliance with the License.\n",
    "\n",
    "You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "Unless required by applicable law or agreed to in writing, software\n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "\n",
    "See the License for the specific language governing permissions and\n",
    "limitations under the License.\n",
    "\n",
    "Cybertron tutorial 01: Quick introduction of Cybertron"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] ME(9802:139930080384832,MainProcess):2022-08-10-17:06:34.392.040 [mindspore/run_check/_check_version.py:137] Can not found cuda libs, please confirm that the correct cuda version has been installed, you can refer to the installation guidelines: https://www.mindspore.cn/install\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import time\n",
    "import numpy as np\n",
    "from mindspore import nn\n",
    "from mindspore import context\n",
    "from mindspore import dataset as ds\n",
    "from mindspore.train import Model\n",
    "from mindspore.train.callback import LossMonitor\n",
    "from mindspore.train.callback import ModelCheckpoint, CheckpointConfig\n",
    "from cybertron import Cybertron\n",
    "from cybertron.train import WithLabelLossCell\n",
    "\n",
    "context.set_context(mode=context.GRAPH_MODE, device_target=\"GPU\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_file = sys.path[0] + '/dataset_qm9_origin_trainset_1024.npz'\n",
    "train_data = np.load(train_file)\n",
    "\n",
    "idx = [0]  # diple\n",
    "num_atom = int(train_data['num_atoms'])\n",
    "scale = train_data['scale'][idx]\n",
    "shift = train_data['shift'][idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================\n",
      "Cybertron Engine, Ride-on!\n",
      "--------------------------------------------------------------------------------\n",
      "    Length unit: nm\n",
      "    Input unit scale: 1\n",
      "--------------------------------------------------------------------------------\n",
      "    Deep molecular model:  SchNet\n",
      "--------------------------------------------------------------------------------\n",
      "       Length unit: nm\n",
      "       Atom embedding size: 64\n",
      "       Cutoff distance: 1.0 nm\n",
      "       Radical basis function (RBF): GaussianBasis\n",
      "          Minimum distance: 0.0 nm\n",
      "          Maximum distance: 1.0 nm\n",
      "          Sigma for Gaussian: 0.03 nm\n",
      "          Interval for Gaussian: 0.016 nm\n",
      "          Number of basis functions: 64\n",
      "       Calculate distance: Yes\n",
      "       Calculate bond: No\n",
      "       Feature dimension: 64\n",
      "--------------------------------------------------------------------------------\n",
      "       Using 3 independent interaction layers:\n",
      "--------------------------------------------------------------------------------\n",
      "       0. SchNet Interaction Layer\n",
      "          Feature dimension: 64\n",
      "          Activation function: ShiftedSoftplus\n",
      "          Dimension for filter network: 64\n",
      "--------------------------------------------------------------------------------\n",
      "       1. SchNet Interaction Layer\n",
      "          Feature dimension: 64\n",
      "          Activation function: ShiftedSoftplus\n",
      "          Dimension for filter network: 64\n",
      "--------------------------------------------------------------------------------\n",
      "       2. SchNet Interaction Layer\n",
      "          Feature dimension: 64\n",
      "          Activation function: ShiftedSoftplus\n",
      "          Dimension for filter network: 64\n",
      "--------------------------------------------------------------------------------\n",
      "    Readout network: GraphReadout\n",
      "--------------------------------------------------------------------------------\n",
      "       Activation function: ShiftedSoftplus\n",
      "       Decoder: HalveDecoder\n",
      "       Aggregator: TensorMean\n",
      "       Representation dimension: 64\n",
      "       Readout dimension: 1\n",
      "       Scale: [1.5031829]\n",
      "       Shift: [2.6728272]\n",
      "       Scaleshift mode: Graph\n",
      "       Reference value for atom types: None\n",
      "       Output unit: kJ mol-1\n",
      "       Reduce axis: -2\n",
      "--------------------------------------------------------------------------------\n",
      "    Output dimension: 1\n",
      "    Output unit for Cybertron: kJ mol-1\n",
      "    Output unit scale: 1.0\n",
      "================================================================================\n"
     ]
    }
   ],
   "source": [
    "net = Cybertron(model='schnet', readout='graph', dim_output=1,\n",
    "                num_atoms=num_atom, length_unit='nm', energy_unit='kj/mol')\n",
    "net.set_scaleshift(scale, shift)\n",
    "net.print_info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 model.atom_embedding.embedding_table (64, 64)\n",
      "1 model.interactions.0.atomwise_bc.weight (64, 64)\n",
      "2 model.interactions.0.atomwise_bc.bias (64,)\n",
      "3 model.interactions.0.atomwise_ac.mlp.0.weight (64, 64)\n",
      "4 model.interactions.0.atomwise_ac.mlp.0.bias (64,)\n",
      "5 model.interactions.0.atomwise_ac.mlp.1.weight (64, 64)\n",
      "6 model.interactions.0.atomwise_ac.mlp.1.bias (64,)\n",
      "7 model.interactions.0.dis_filter.dense_layers.mlp.0.weight (64, 64)\n",
      "8 model.interactions.0.dis_filter.dense_layers.mlp.0.bias (64,)\n",
      "9 model.interactions.0.dis_filter.dense_layers.mlp.1.weight (64, 64)\n",
      "10 model.interactions.0.dis_filter.dense_layers.mlp.1.bias (64,)\n",
      "11 model.interactions.1.atomwise_bc.weight (64, 64)\n",
      "12 model.interactions.1.atomwise_bc.bias (64,)\n",
      "13 model.interactions.1.atomwise_ac.mlp.0.weight (64, 64)\n",
      "14 model.interactions.1.atomwise_ac.mlp.0.bias (64,)\n",
      "15 model.interactions.1.atomwise_ac.mlp.1.weight (64, 64)\n",
      "16 model.interactions.1.atomwise_ac.mlp.1.bias (64,)\n",
      "17 model.interactions.1.dis_filter.dense_layers.mlp.0.weight (64, 64)\n",
      "18 model.interactions.1.dis_filter.dense_layers.mlp.0.bias (64,)\n",
      "19 model.interactions.1.dis_filter.dense_layers.mlp.1.weight (64, 64)\n",
      "20 model.interactions.1.dis_filter.dense_layers.mlp.1.bias (64,)\n",
      "21 model.interactions.2.atomwise_bc.weight (64, 64)\n",
      "22 model.interactions.2.atomwise_bc.bias (64,)\n",
      "23 model.interactions.2.atomwise_ac.mlp.0.weight (64, 64)\n",
      "24 model.interactions.2.atomwise_ac.mlp.0.bias (64,)\n",
      "25 model.interactions.2.atomwise_ac.mlp.1.weight (64, 64)\n",
      "26 model.interactions.2.atomwise_ac.mlp.1.bias (64,)\n",
      "27 model.interactions.2.dis_filter.dense_layers.mlp.0.weight (64, 64)\n",
      "28 model.interactions.2.dis_filter.dense_layers.mlp.0.bias (64,)\n",
      "29 model.interactions.2.dis_filter.dense_layers.mlp.1.weight (64, 64)\n",
      "30 model.interactions.2.dis_filter.dense_layers.mlp.1.bias (64,)\n",
      "31 readout.decoder.output.mlp.0.weight (32, 64)\n",
      "32 readout.decoder.output.mlp.0.bias (32,)\n",
      "33 readout.decoder.output.mlp.1.weight (1, 32)\n",
      "34 readout.decoder.output.mlp.1.bias (1,)\n",
      "Total parameters:  68609\n"
     ]
    }
   ],
   "source": [
    "tot_params = 0\n",
    "for i, param in enumerate(net.get_parameters()):\n",
    "    tot_params += param.size\n",
    "    print(i, param.name, param.shape)\n",
    "print('Total parameters: ', tot_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WithLabelLossCell with input type: RZE\n"
     ]
    }
   ],
   "source": [
    "n_epoch = 8\n",
    "repeat_time = 1\n",
    "batch_size = 32\n",
    "\n",
    "ds_train = ds.NumpySlicesDataset(\n",
    "    {'R': train_data['R'], 'Z': train_data['Z'], 'E': train_data['E'][:, idx]}, shuffle=True)\n",
    "ds_train = ds_train.batch(batch_size, drop_remainder=True)\n",
    "ds_train = ds_train.repeat(repeat_time)\n",
    "loss_network = WithLabelLossCell('RZE', net, nn.MAELoss())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = 1e-3\n",
    "optim = nn.Adam(params=net.trainable_params(), learning_rate=lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Model(loss_network, optimizer=optim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "monitor_cb = LossMonitor(16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "outdir = 'Tutorial_C01'\n",
    "params_name = outdir + '_' + net.model_name\n",
    "config_ck = CheckpointConfig(\n",
    "    save_checkpoint_steps=32, keep_checkpoint_max=64, append_info=[net.hyper_param])\n",
    "ckpoint_cb = ModelCheckpoint(\n",
    "    prefix=params_name, directory=outdir, config=config_ck)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Start training ...\n",
      "epoch: 1 step: 16, loss is 1.2756918668746948\n",
      "epoch: 1 step: 32, loss is 0.992481529712677\n",
      "epoch: 2 step: 16, loss is 1.0690456628799438\n",
      "epoch: 2 step: 32, loss is 0.994649350643158\n",
      "epoch: 3 step: 16, loss is 0.8286229372024536\n",
      "epoch: 3 step: 32, loss is 0.7207255363464355\n",
      "epoch: 4 step: 16, loss is 0.8941224217414856\n",
      "epoch: 4 step: 32, loss is 0.7824708819389343\n",
      "epoch: 5 step: 16, loss is 0.6775331497192383\n",
      "epoch: 5 step: 32, loss is 0.8148207664489746\n",
      "epoch: 6 step: 16, loss is 0.8503636121749878\n",
      "epoch: 6 step: 32, loss is 0.6760039329528809\n",
      "epoch: 7 step: 16, loss is 1.0324618816375732\n",
      "epoch: 7 step: 32, loss is 0.7412121891975403\n",
      "epoch: 8 step: 16, loss is 1.0068881511688232\n",
      "epoch: 8 step: 32, loss is 0.6990125179290771\n",
      "Training Fininshed!\n",
      "Training Time: 00:00:15\n"
     ]
    }
   ],
   "source": [
    "print(\"Start training ...\")\n",
    "beg_time = time.time()\n",
    "model.train(n_epoch, ds_train, callbacks=[\n",
    "    monitor_cb, ckpoint_cb], dataset_sink_mode=False)\n",
    "end_time = time.time()\n",
    "used_time = end_time - beg_time\n",
    "m, s = divmod(used_time, 60)\n",
    "h, m = divmod(m, 60)\n",
    "print(\"Training Fininshed!\")\n",
    "print(\"Training Time: %02d:%02d:%02d\" % (h, m, s))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.5 ('mindsponge')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "2496ecc683137a232cae2452fbbdd53dab340598b6e499c8995be760f3a431b4"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
